package com.region.loadbalancer.policy;

import com.region.loadbalancer.group.Server;
import com.region.loadbalancer.group.ServerState;

import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.ThreadLocalRandom;

/**
 * Base class for weight-based load-balancing policies.
 *
 * <p>Subclasses supply the configured weight of each server through {@link #setAllWeight(List)}
 * and {@link #getWeightByIndex(int)}. This class keeps an index-aligned list of effective weights
 * in {@link #originWeight}. When the per-server counters ({@code Server#getCount()}) drift far
 * enough from the configured weights, servers that lag behind get a temporarily boosted
 * effective weight and selection switches to weighted-random. Otherwise the {@code UP} server
 * with the lowest counter-to-weight ratio is chosen.
 *
 * @author liujieyu
 * @date 2023/5/27 16:32
 * @description
 */
public abstract class WeightPolicy implements AllServerBalancerPolicy {

    /** Not read by this base class; presumably maintained by subclasses — TODO confirm. */
    protected List<Integer> realWeight;

    /**
     * Effective weight of each server, index-aligned with the list passed to
     * {@link #choose(List)}. Starts out equal to the configured weight and may be boosted by
     * {@link #recalculateWeight(List)}.
     */
    protected List<Integer> originWeight = new CopyOnWriteArrayList<>();

    /** Counter-to-weight ratio that must be exceeded before effective weights are recalculated. */
    protected int multiple = 10;

    /**
     * Threshold on a server's relative shortfall ({@code 1 - count / (maxRatio * weight)}) above
     * which its effective weight is boosted during recalculation.
     */
    protected double differ = 0.1;

    /**
     * Prepares the configured weights for {@code servers}. Called at the start of every
     * {@link #choose(List)}; {@link #getWeightByIndex(int)} is presumably expected to reflect
     * them afterwards.
     *
     * @param servers the candidate servers
     * @throws UnsupportedOperationException if the implementation does not support this operation
     */
    protected abstract void setAllWeight(List<Server> servers) throws UnsupportedOperationException;

    /**
     * Returns the configured weight of the server at {@code index} in the list passed to
     * {@link #choose(List)}.
     *
     * @param index position of the server in the candidate list
     * @return the configured weight
     * @throws UnsupportedOperationException if the implementation does not support this operation
     */
    protected abstract int getWeightByIndex(int index) throws UnsupportedOperationException;

    /**
     * Chooses a server from {@code servers}.
     *
     * <p>Configured weights are prepared via {@link #setAllWeight(List)} and the effective weights
     * are recalculated first. If any effective weight differs from its configured weight, an
     * {@code UP} server is drawn at random in proportion to its effective weight. Otherwise the
     * {@code UP} server with the lowest counter-to-weight ratio is returned. Servers with a
     * non-positive effective weight are skipped by both strategies. If neither strategy finds a
     * server, the first {@code UP} server is returned.
     *
     * @param servers the candidate servers; the weight at index {@code i} is taken to belong to
     *                {@code servers.get(i)}
     * @return the chosen server, or {@code null} if no server is {@code UP}
     */
    @Override
    public Server choose(List<Server> servers) {
        setAllWeight(servers);
        recalculateWeight(servers);
        Server selected = useRandomRatio(servers)
                ? chooseByRandomWeight(servers)
                : chooseByLowestRatio(servers);
        // Fall back to any available server when the strategy found none (e.g. all weights <= 0).
        return selected != null ? selected : firstUpServer(servers);
    }

    /** Picks an UP server at random with probability proportional to its effective weight. */
    private Server chooseByRandomWeight(List<Server> servers) {
        long total = 0;
        for (int i = 0; i < servers.size(); i++) {
            if (servers.get(i).getStatus() == ServerState.UP) {
                total += Math.max(0, originWeight.get(i));
            }
        }
        if (total <= 0) {
            // nextLong(bound) rejects a non-positive bound; let the caller fall back instead.
            return null;
        }
        // ThreadLocalRandom must be obtained by the calling thread, never cached in a field.
        long pick = ThreadLocalRandom.current().nextLong(total);
        long cumulative = 0;
        for (int i = 0; i < servers.size(); i++) {
            if (servers.get(i).getStatus() == ServerState.UP) {
                cumulative += Math.max(0, originWeight.get(i));
                if (pick < cumulative) {
                    return servers.get(i);
                }
            }
        }
        return null;
    }

    /** Picks the UP server with the smallest counter-to-effective-weight ratio (first on ties). */
    private Server chooseByLowestRatio(List<Server> servers) {
        Server best = null;
        double bestRatio = Double.POSITIVE_INFINITY;
        for (int i = 0; i < servers.size(); i++) {
            Server server = servers.get(i);
            int weight = originWeight.get(i);
            // A non-positive weight would yield NaN/Infinity and corrupt the minimum search.
            if (server.getStatus() != ServerState.UP || weight <= 0) {
                continue;
            }
            double ratio = 1.0 * server.getCount().get() / weight;
            if (best == null || ratio < bestRatio) {
                best = server;
                bestRatio = ratio;
            }
        }
        return best;
    }

    /** Returns the first UP server, or {@code null} if none is UP. */
    private Server firstUpServer(List<Server> servers) {
        for (Server server : servers) {
            if (server.getStatus() == ServerState.UP) {
                return server;
            }
        }
        return null;
    }

    /**
     * Returns {@code true} if any effective weight differs from its configured weight, i.e. a
     * recalculation has boosted at least one server.
     */
    private boolean useRandomRatio(List<Server> servers) {
        for (int i = 0; i < servers.size(); i++) {
            if (originWeight.get(i).intValue() != getWeightByIndex(i)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Keeps {@link #originWeight} aligned with {@code servers} and, when
     * {@link #isRecalculate(List)} says so, boosts the effective weight of servers whose
     * counter-to-weight ratio lags the maximum by more than {@link #differ}.
     *
     * @param servers the candidate servers
     */
    private void recalculateWeight(List<Server> servers) {
        if (originWeight.size() != servers.size()) {
            // Rebuild from scratch. Inserting without clearing would leave stale entries behind,
            // so the size would never match again and the list would grow on every call.
            List<Integer> rebuilt = new CopyOnWriteArrayList<>();
            for (int i = 0; i < servers.size(); i++) {
                rebuilt.add(getWeightByIndex(i));
            }
            originWeight = rebuilt;
        }
        // If the counters exceed the configured ratio, the weights must be recalculated.
        if (isRecalculate(servers)) {
            long maxRatio = getMaxWeight(servers);
            for (int i = 0; i < servers.size(); i++) {
                int weight = getWeightByIndex(i);
                // For weight 0 the divisor is 0 and scale becomes -Infinity/NaN, so it is not boosted.
                double scale = 1 - 1.0 * servers.get(i).getCount().get() / (maxRatio * weight);
                if (scale > differ) {
                    originWeight.set(i, (int) Math.ceil(weight + 2 * (weight * scale)));
                } else {
                    originWeight.set(i, weight);
                }
            }
        }
    }

    /**
     * Decides whether effective weights should be recalculated by comparing a server's
     * integer counter-to-weight ratio with {@link #multiple}.
     *
     * <p>NOTE(review): only the first server that is not skipped decides the result. DOWN servers
     * below the threshold and servers with a non-positive weight are skipped. Confirm this is
     * intended rather than an "any server exceeds" check.
     *
     * @param servers the candidate servers
     * @return {@code true} if the deciding server's ratio is strictly greater than {@code multiple}
     */
    private boolean isRecalculate(List<Server> servers) {
        for (int i = 0; i < servers.size(); i++) {
            int weight = getWeightByIndex(i);
            if (weight <= 0) {
                // No meaningful ratio, and integer division by zero would throw.
                continue;
            }
            long ratio = servers.get(i).getCount().get() / weight;
            if (ratio < multiple && servers.get(i).getStatus() == ServerState.DOWN) {
                continue;
            }
            return ratio > multiple;
        }
        return false;
    }

    /**
     * Returns the largest integer counter-to-weight ratio over all servers, regardless of
     * status. Servers with a non-positive weight are ignored.
     *
     * @param servers the candidate servers
     * @return the maximum ratio, or {@code -1} if no server has a positive weight
     */
    private long getMaxWeight(List<Server> servers) {
        long max = -1;
        for (int i = 0; i < servers.size(); i++) {
            int weight = getWeightByIndex(i);
            if (weight <= 0) {
                continue;
            }
            long ratio = servers.get(i).getCount().get() / weight;
            max = Math.max(max, ratio);
        }
        return max;
    }

}
